import torch
from entity_extractor import *
from visualizer import *
import ast
import pandas as pd
import pickle
import math
import sys

''' 
# Setup instructions for IPython shell exploration:
# 1. Highlight the following commands and press alt+shift+e
%run entity_extractor.py
%reload_ext autoreload
%autoreload 2
# 2. Highlight the code section you would like to examine and press alt+shift+e
'''

# When True, the __main__ block reads the base directory, the input file name and
# the raw relationships file name from sys.argv[1], sys.argv[2] and sys.argv[3].
# When False, BASE_DIR, INPUT_FILE_NAME and RELATIONSHIPS_RAW_NAME below are used instead.
USE_CMDLINE_BASE_DIR = True

# Base directory where input files are loaded and output files are saved.
# File names are appended to it by plain string concatenation, so it should
# end with a path separator.
BASE_DIR = "[PATH_TO_BASE_DIR]"
# Input file of posts (passed to the first-mention visualization experiment)
INPUT_FILE_NAME = "[PATH_TO_INPUT_FILE]"
# Dataset name (forwarded to EntityExtractor as dataset_name)
DATASET = "[DATASET_NAME]"  
# Raw relationship file name (forwarded to EntityExtractor as df_extraction_raw_name)
RELATIONSHIPS_RAW_NAME = "[PATH_TO_RELATIONSHIPS_FILE]"

## If this is the first time running this code, the following files will be generated
# Processed relationship file name which contains NER tags
DF_EXTRACTION_NAME = "[EXTRACTION_FILE_NAME]"
# Name of the dataframe (and csv file) which contains the ranking of the named entities
DF_NER_RANKING_NAME = "[NER_RANKING_FILE_NAME]"
# Name of the dataframe (and csv file) which contains the ranking of the argument headwords
DF_ARG_RANKING_NAME = "[ARG_RANKING_FILE_NAME]"

# Set to True to regenerate NER tags and flair sentences.
# Set to False to reuse previously generated results and save computation time.
REGENERATE_DF_EXTRACTIONS_WITH_NER_FLAIR_SENTENCES_AND_TAGS = True

# Set to True to overwrite the existing NER ranking
# (also used as the overwrite flag for the final entity ranking)
OVERWRITE_NER_RANKING = True

# Number of top-ranked entities to embed and cluster; also part of the
# embedding pickle file name written and read by this script
NUMBER_OF_ENTITIES_TO_CLUSTER = 50  # 130 was used for the paper


def generate_avg_embeddings_for_entities_tensorboard_format(base_dir, file_postfix):
    """
    Export averaged entity embeddings as tab-separated files for TensorBoard.

    Loads the pickled entity dictionary stored at
    base_dir + "[ENTITY_EMBEDDING_PREFIX]" + file_postfix + ".pkl" and writes two
    tab-separated files into base_dir: one with the embedding vectors (no header,
    no index) and one with the matching "Text" / "Type" metadata columns.
    Entries whose "embedding" value is a plain float are left out of both files.

    Args:
        base_dir (str): Directory prefix, concatenated directly with the file names
        file_postfix (str): Postfix shared by the input pickle and both output files
    """
    pickle_path = base_dir + "[ENTITY_EMBEDDING_PREFIX]" + file_postfix + ".pkl"
    with open(pickle_path, "rb") as handle:
        ent_embeddings = pickle.load(handle)

    print(ent_embeddings)

    # Keep only entries that carry a real vector; a bare float is treated as invalid.
    usable = [
        (name, info)
        for name, info in ent_embeddings.items()
        if not isinstance(info["embedding"], float)
    ]
    embs = [info["embedding"] for _, info in usable]

    vectors_path = base_dir + file_postfix + "[EMBEDDINGS_OUTPUT_SUFFIX]"
    pd.DataFrame(embs).to_csv(vectors_path, sep='\t', index=False, header=False)

    metadata = pd.DataFrame([name for name, _ in usable], columns=["Text"])
    metadata["Type"] = [info["type"] for _, info in usable]
    metadata_path = base_dir + file_postfix + "[METADATA_OUTPUT_SUFFIX]"
    metadata.to_csv(metadata_path, sep='\t', index=False, header=True)
    print(embs)


def get_entity_versions(dataset="[DEFAULT_DATASET]"):
    """
    Return the hand-curated name variants of each entity for a dataset.

    Args:
        dataset (str): Name of the dataset

    Returns:
        dict: defaultdict(list) mapping each canonical entity name to the list of
        its lower-cased variants; empty when the dataset name is not recognised
    """
    if dataset == "[DATASET_TYPE_1]":
        raw_variants = {
            '[ENTITY_1]': ['[VARIANT_1A]', '[VARIANT_1B]', '[VARIANT_1C]'],
            '[ENTITY_2]': ['[VARIANT_2A]', '[VARIANT_2B]'],
            '[ENTITY_3]': ['[VARIANT_3A]', '[VARIANT_3B]', '[VARIANT_3C]', '[VARIANT_3D]', '[VARIANT_3E]'],
            # Additional entities would be defined similarly
        }
    elif dataset == "[DATASET_TYPE_2]":
        raw_variants = {
            '[ENTITY_1]': ['[FULL_NAME_1]'],
            '[ENTITY_2]': ['[FULL_NAME_2]'],
            '[ENTITY_3]': ['[FULL_NAME_3]'],
            # Additional entities would be defined similarly
        }
    else:
        raw_variants = {}

    # Variants are stored lower-cased; canonical names keep their casing.
    entity_versions = defaultdict(list)
    for canonical_name, variants in raw_variants.items():
        entity_versions[canonical_name] = [variant.lower() for variant in variants]

    return entity_versions


def generate_entity_versions_automatically(base_dir, entity_version_pickle_name="[ENTITY_VERSION_FILE]", min_freq=-1, overwrite=False):
    """
    Automatically generate entity versions from ranking data.

    The result is cached as a pickle whose name is derived from
    entity_version_pickle_name (text before the first ".") and min_freq. When
    that file exists and overwrite is False, it is loaded instead of being
    rebuilt. Otherwise the entity ranking CSV in base_dir is filtered down to
    two-word PERSON entities, the filtered table is written to a CSV, and the
    resulting mapping is pickled.

    Variants are lower-cased on return for BOTH the cached and the freshly
    generated path, so callers get the same shape of result either way. The
    pickle on disk keeps the variants as they were generated.

    Args:
        base_dir (str): Base directory path (concatenated directly with file names)
        entity_version_pickle_name (str): Name for the pickle file to save entity versions
        min_freq (int): Minimum frequency for entities to be included (-1 means no minimum)
        overwrite (bool): Whether to overwrite existing file

    Returns:
        dict: Dictionary mapping capitalized full names to a list of lower-cased variants
    """
    path_to_file = base_dir + entity_version_pickle_name.split(".")[0] + "_" + str(min_freq) + ".pkl"

    if os.path.isfile(path_to_file) and not overwrite:
        # NOTE(security): pickle.load can execute arbitrary code; only use a trusted base_dir.
        with open(path_to_file, "rb") as f:
            entity_versions = pickle.load(f)
    else:
        entity_versions = defaultdict(list)
        df_ent_final_ranking = pd.read_csv(base_dir + "[ENTITY_RANKING_FILE]")
        df_persons = df_ent_final_ranking[df_ent_final_ranking["type"] == "PERSON"]

        # Drop entities which are either article writers or noise from NER pipeline
        entities_to_drop = ["[NOISE_ENTITY_1]", "[NOISE_ENTITY_2]"]
        df_persons = df_persons[~df_persons["entity"].isin(entities_to_drop)]
        df_persons = df_persons.dropna()  # Drop NER mistakes

        if min_freq != -1:
            df_persons = df_persons[df_persons["frequency_score_sum_NER_arg"] >= min_freq]

        # Work on an explicit copy so the column assignment below does not
        # write into a slice of the ranking dataframe.
        df_persons = df_persons.copy()

        def get_cap_fullnames_only(ent):
            """Return the capitalized name for two-word entities, "" otherwise."""
            return ent_capitalized(ent) if len(ent.split(" ")) == 2 else ""

        # Replace each entity by its capitalized full name and keep only full names
        df_persons["entity"] = df_persons["entity"].apply(get_cap_fullnames_only)
        df_persons = df_persons[df_persons["entity"] != ""]
        df_persons.to_csv(base_dir + "[PERSONS_OUTPUT_PREFIX]" + str(min_freq) + ".csv")

        for ent in df_persons["entity"]:
            if len(ent.split(" ")) == 2:
                ent_cap = ent_capitalized(ent)
                if ent_cap not in entity_versions:
                    entity_versions[ent_cap] = [ent]

        with open(path_to_file, "wb") as f:
            pickle.dump(entity_versions, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Convert all entity variants to lowercase (applies to cached results as well)
    for ent_version_list in entity_versions.values():
        ent_version_list[:] = [version.lower() for version in ent_version_list]

    print("Number of total entities: ", len(entity_versions))
    return entity_versions


def ent_capitalized(ent):
    """
    Capitalize each space-separated word in an entity name.

    Only the first character of every word is upper-cased; the rest of the word
    is left untouched (so "mcDonald" becomes "McDonald"). Empty words, which
    appear for an empty string or for leading, trailing or repeated spaces, are
    passed through instead of raising IndexError. Surrounding whitespace is
    stripped from the result.

    Args:
        ent (str): Entity name

    Returns:
        str: Capitalized entity name
    """
    # w[:1] is "" for an empty word, whereas w[0] would raise IndexError.
    return " ".join(w[:1].upper() + w[1:] for w in ent.split(" ")).strip()


def get_entity_versions_reverse_mapping(dataset="bridgegate_entity_trends"):
    """
    Get reverse mapping from entity variants to canonical names.

    Built from get_entity_versions(dataset). If a variant is listed under more
    than one canonical name, the canonical name seen last wins.

    Args:
        dataset (str): Name of the dataset

    Returns:
        dict: defaultdict(list) mapping each variant to its canonical entity
        name. Because of the default factory, looking up an unknown variant
        yields an empty list (and inserts it) rather than raising KeyError.
    """
    entity_versions = get_entity_versions(dataset)
    entity_versions_reverse_mapping = defaultdict(list)

    # dict.iteritems() does not exist in Python 3; items() works on both mappings here
    for ent_glob_name, ent_version_list in entity_versions.items():
        for ent_version in ent_version_list:
            entity_versions_reverse_mapping[ent_version] = ent_glob_name

    return entity_versions_reverse_mapping


def _get_main_persons(base_dir, entity_min_freq):
    """
    Compare entity list with an entity list provided by a news page.
    Used for experimentation, results not used in paper.

    Reads the persons CSV for the given minimum frequency from base_dir, adds a
    boolean "exist_in_reference" column telling whether the lower-cased entity
    occurs as a substring of the lower-cased reference text, and writes the
    result to a new CSV in base_dir.

    Args:
        base_dir (str): Base directory path
        entity_min_freq (int): Minimum frequency for entities
    """
    import codecs

    path_to_reference = base_dir + "[REFERENCE_TEXT_FILE]"
    # Context manager guarantees the reference file is closed, even on errors
    with codecs.open(path_to_reference, "r", "utf-8") as f:
        txt_reference = " ".join(f.readlines()).lower()

    df_persons = pd.read_csv(base_dir + "[PERSONS_OUTPUT_PREFIX]" + str(entity_min_freq) + ".csv")

    df_persons["exist_in_reference"] = df_persons["entity"].apply(
        lambda entity: entity.lower() in txt_reference
    )
    df_persons.to_csv(base_dir + "[PERSONS_REFERENCE_OUTPUT_PREFIX]" + str(entity_min_freq) + ".csv")


def experiment_visualize_first_mention_of_entities(base_dir, input_file_name, create_new_ents_dict=False):
    """
    Run the first-mention visualization experiment.

    Regenerates the automatic entity versions (always overwriting the cached
    pickle, with no minimum frequency), optionally rebuilds the first-mention
    dictionary from the input file, and then visualizes that dictionary.

    Args:
        base_dir (str): Base directory path
        input_file_name (str): Input file name
        create_new_ents_dict (bool): Whether to create a new entities dictionary
            before visualizing; when False an existing dictionary file is used
    """
    vis = visualizer(base_dir)

    # -1 disables the minimum-frequency filter; the value is also part of the file names
    entity_min_freq = -1
    min_freq_tag = str(entity_min_freq)
    new_ent_dict_name = "[NEW_ENTITIES_DICT_PREFIX]" + min_freq_tag + ".pkl"

    entity_versions = generate_entity_versions_automatically(
        base_dir,
        entity_version_pickle_name="entity_versions_auto_generated.pkl",
        min_freq=entity_min_freq,
        overwrite=True,
    )

    if create_new_ents_dict:
        vis.create_first_mention_of_entities_dict(
            input_file_name,
            entity_versions=entity_versions,
            output_name=new_ent_dict_name,
            generate_df_with_dates=False,
        )

    vis.visualize_new_ents_dict(
        base_dir + new_ent_dict_name,
        output_post_fix_name="minFreq_" + min_freq_tag,
    )


def experiment_main_generate_ent_rankings(
    base_dir,
    df_extraction_raw_name="[DEFAULT_EXTRACTION_RAW_NAME]",
    df_extraction_name="[DEFAULT_EXTRACTION_NAME]",
    df_ner_ranking_name="[DEFAULT_NER_RANKING_NAME]",
    df_arg_ranking_name="[DEFAULT_ARG_RANKING_NAME]",
    dataset_name="[DEFAULT_DATASET_NAME]",
    regenerate_df_extractions_with_ner_flair_sentences_and_tags=False,
    overwrite_ner_ranking=False
):
    """
    Generate the entity ranking and one averaged embedding per top entity.

    Builds an EntityExtractor, generates or loads the final entity ranking,
    collects embeddings for the top NUMBER_OF_ENTITIES_TO_CLUSTER entities,
    averages them per entity and pickles the result into base_dir.

    Args:
        base_dir (str): Base directory path
        df_extraction_raw_name (str): Raw extraction file name
        df_extraction_name (str): Extraction file name with NER
        df_ner_ranking_name (str): NER ranking file name
        df_arg_ranking_name (str): Argument ranking file name
        dataset_name (str): Dataset name
        regenerate_df_extractions_with_ner_flair_sentences_and_tags (bool): Whether to regenerate extractions
        overwrite_ner_ranking (bool): Whether to overwrite existing NER ranking
            (also forwarded as the overwrite flag for the final ranking)
    """
    extractor = EntityExtractor(
        base_dir,
        df_extraction_raw_name,
        df_extraction_name,
        df_ner_ranking_name,
        df_arg_ranking_name,
        dataset_name,
        regenerate_df_extractions_with_ner_flair_sentences_and_tags,
        overwrite_ner_ranking,
    )

    final_ranking_path = base_dir + "[ENTITY_FINAL_RANKING_FILE]"
    df_ent_final_ranking = extractor.generate_or_load_final_ent_ranking(
        path_to_file=final_ranking_path,
        overwrite=overwrite_ner_ranking,
    )
    print(df_ent_final_ranking.head())

    # Time the embedding lookup and report it in minutes
    started_at = time.time()
    ent_emb_lists = extractor.get_ent_emb_dict(
        df_ent_final_ranking,
        only_top_N_entitis=NUMBER_OF_ENTITIES_TO_CLUSTER,
    )
    elapsed_minutes = (time.time() - started_at) / 60.0
    print("Entity embedding generation done - execution time: ", elapsed_minutes)
    print("Entity lists:", ent_emb_lists.keys())

    # Collapse each entity's list of embeddings into a single mean vector
    ent_single_emb_lists = {
        name: {
            "type": details["type"],
            "count": details["count"],
            "embedding": np.mean(details["embeddings"], axis=0),
        }
        for name, details in ent_emb_lists.items()
    }

    print(ent_single_emb_lists)
    PIK = base_dir + "[ENTITY_EMBEDDING_OUTPUT_PREFIX]" + str(NUMBER_OF_ENTITIES_TO_CLUSTER) + ".pkl"
    print("Saving pickle object at: ", PIK)
    with open(PIK, "wb") as out_file:
        pickle.dump(ent_single_emb_lists, out_file, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == '__main__':
    # Resolve the run configuration first, so a bad invocation fails with a
    # usage message before any NLTK resources are downloaded.
    if USE_CMDLINE_BASE_DIR:
        if len(sys.argv) < 4:
            sys.exit(
                "Usage: python " + sys.argv[0]
                + " <base_dir> <input_file_name> <relationships_raw_name>"
            )
        base_dir = str(sys.argv[1])
        input_file_name = str(sys.argv[2])
        rel_raw_name = str(sys.argv[3])
    else:
        base_dir = BASE_DIR
        input_file_name = INPUT_FILE_NAME
        rel_raw_name = RELATIONSHIPS_RAW_NAME

    nltk.download('averaged_perceptron_tagger')
    nltk.download('stopwords')

    # Step 1: rank the entities and pickle one averaged embedding per top entity
    experiment_main_generate_ent_rankings(
        base_dir,
        df_extraction_raw_name=rel_raw_name,
        df_extraction_name=DF_EXTRACTION_NAME,
        df_ner_ranking_name=DF_NER_RANKING_NAME,
        df_arg_ranking_name=DF_ARG_RANKING_NAME,
        dataset_name=DATASET,
        regenerate_df_extractions_with_ner_flair_sentences_and_tags=REGENERATE_DF_EXTRACTIONS_WITH_NER_FLAIR_SENTENCES_AND_TAGS,
        overwrite_ner_ranking=OVERWRITE_NER_RANKING
    )

    # Step 2: build and visualize the first mentions of the entities
    experiment_visualize_first_mention_of_entities(
        base_dir,
        input_file_name=input_file_name,
        create_new_ents_dict=True
    )

    # Generate embeddings for visualizing the entities in TensorBoard
    file_postfix = "[FILE_POSTFIX_PREFIX]" + str(NUMBER_OF_ENTITIES_TO_CLUSTER)
    generate_avg_embeddings_for_entities_tensorboard_format(base_dir=base_dir, file_postfix=file_postfix)

    # Visualize entities into 2D plot using PCA projection
    vis = visualizer(base_dir)
    # Same pickle that experiment_main_generate_ent_rankings wrote in step 1
    path_to_entity_emb_dict = base_dir + "[ENTITY_EMBEDDING_OUTPUT_PREFIX]" + str(NUMBER_OF_ENTITIES_TO_CLUSTER) + ".pkl"
    vis.visualize_clusters(path_to_entity_emb_dict, output_file_name="[CLUSTER_VISUALIZATION_OUTPUT]")
